import mindspore as ms
import torch
from mindspore import ops as P

import compare_acc

# Compare a single weight tensor between a MindSpore checkpoint and a PyTorch
# checkpoint by converting both to NumPy arrays and passing them to
# compare_acc.compare.

# Both loaders return objects that are indexed by parameter name below.
param = ms.load_checkpoint("/lgsl_data1/huawei_wz/telechat_flm_ckpt/mindspore.ckpt")
# map_location="cpu" keeps every tensor on the host: without it, a checkpoint
# saved from an accelerator is restored onto that device, which fails on a
# machine lacking it and makes the later .numpy() call raise.
torch_param = torch.load(
    "/lgsl_data1/huawei_wz/iter_0000010_HF_FLM2_fix_head_new/pytorch_model.bin",
    map_location="cpu",
)

reshape = P.Reshape()
# The MindSpore weight is reshaped to (21824, 8192) before comparison.
# NOTE(review): presumably this is the shape of the PyTorch tensor it is
# compared against -- confirm, since the shape is hard-coded here.
ms_param = reshape(param["model.layers.0.feed_forward.dense_4h_to_h.weight"], (21824, 8192))
ms_weight = ms_param.asnumpy()
# detach() is a no-op for plain tensors but lets .numpy() succeed if the stored
# tensor requires grad. The float32 cast happens before the NumPy conversion;
# NOTE(review): presumably the checkpoint stores a dtype NumPy cannot
# represent (e.g. bfloat16) -- confirm.
pt_weight = (
    torch_param["transformer.h.0.mlp.c_proj.weight"]
    .detach()
    .to(torch.float32)
    .numpy()
)
print("start to compare diff...")
compare_acc.compare(ms_weight, pt_weight,
                    save_name_prefix="ms_dense_layer")


